import pathlib

import matplotlib.pyplot
import torch

from lotka_volterra_problem import LotkaVolterraProblem
from lotka_volterra_dataset import LotkaVolterraDataset
import torch_plus


def main():
    """Generate a seeded Lotka-Volterra dataset, save it, and plot it.

    The dataset is written to ``./checkpoints/data.pt`` (resolved against the
    current working directory). The parent directory is created first if it
    does not exist yet. Afterwards the two series returned by the dataset
    (``x1`` over ``t1`` and ``x2`` over ``t2``) are drawn as a scatter plot,
    and an interactive matplotlib window is opened.
    """
    output_path = pathlib.Path("./checkpoints/data.pt").resolve()

    # A dedicated, explicitly seeded generator makes the sampled inputs
    # reproducible across runs (assuming torch_plus.rand draws from the
    # generator it is given).
    generator = torch.Generator()
    generator.manual_seed(1234)

    # NOTE(review): the meaning of the positional arguments of
    # LotkaVolterraProblem and torch_plus.rand is defined in those project
    # modules and is not visible here. Presumably rand(20, 100, 110, g)
    # samples 20 values between 100 and 110 for each series -- confirm.
    dataset = LotkaVolterraDataset(
        LotkaVolterraProblem(2, 0.01, 0.005, 1, 300, 150),
        torch_plus.rand(20, 100, 110, generator),
        torch_plus.rand(20, 100, 110, generator))

    # Writers such as torch.save do not create missing directories, so make
    # sure ./checkpoints exists before saving (no-op if it already does).
    output_path.parent.mkdir(parents=True, exist_ok=True)
    dataset.save(output_path)
    print(f"The generated data have been saved to: {output_path}")

    matplotlib.pyplot.scatter(dataset.t1(), dataset.x1(), label="$x_1$")
    matplotlib.pyplot.scatter(dataset.t2(), dataset.x2(), label="$x_2$")
    matplotlib.pyplot.legend()

    # Blocks until the plot window is closed.
    matplotlib.pyplot.show()


# Run main() only when this file is executed directly, not when it is imported.
if __name__ == "__main__":
    main()
